1
要素ごとの演算を超えて、タイル化された行列演算への移行
AI023Lesson 9
00:00

これまでの授業では、 要素ごとの演算 (例えば行列に対する基本的なReLU)。これらは メモリ制限型 GPUがデータをHBMからレジスタへ移動する時間の方が、計算を行う時間よりも長いためです。

1. GEMMの重要性

一般行列乗算(GEMM)は計算量のオーダーが$O(N^3)$である一方、メモリアクセス量は$O(N^2)$で済みます。これにより、膨大な算術演算スループットによってメモリ遅延を隠蔽できるため、大規模言語モデル(LLM)の「心臓部」ともいえます。

2. 2次元メモリ表現

物理的なメモリは1次元です。2次元テンソルを表現するには ストライドを使用します。生産環境でのよくある落とし穴は テンソルが連続していると仮定することです。ポインタの計算で行と列のストライドを混同すると、「ゴーストデータ」にアクセスしたり、メモリ違反を引き起こすことがあります。

3. タイル化の一般化

Tritonは、要素ごとの論理を 単一のポインタ から ポインタのブロックへとシフトすることで一般化しています。2次元タイル(例:$16 \times 16$)を使用することで、 データ再利用 高速なSRAMで効果的に活用でき、バイアス加算や活性化関数などの結合演算のためにデータを「ホット」な状態に保ち、グローバルメモリへの書き戻し前に処理できます。

1次元線形レイアウト2次元タイルレイアウト
main.py
TERMINALbash — 80x24
> Ready. Click "Run" to execute.
>